import torch
import torchvision
from torch import nn
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import transforms


def linear_test():
    """Sanity-check a fully-connected (linear) layer on a tiny fixed batch.

    Builds a 2x3 float32 input tensor and passes it through a freshly
    constructed ``nn.Linear(in_features=3, out_features=5)``. Prints the
    result and returns it so callers and tests can inspect it.

    The layer's weights and bias come from PyTorch's random initialization.
    The output values are therefore only reproducible if the torch RNG is
    seeded before calling.

    Returns:
        torch.Tensor: the layer output, shape ``(2, 5)``, dtype float32.
    """
    # Named ``x`` rather than ``input`` so the built-in ``input()`` is not shadowed.
    x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
    linear = nn.Linear(in_features=3, out_features=5)
    output = linear(x)
    print(output)
    return output


if __name__ == "__main__":
    # Script entry point: run the linear-layer demo only when executed directly,
    # not when this module is imported.
    linear_test()
